import os
import uuid
import types
from dataclasses import asdict, dataclass
from typing import Any, DefaultDict, Dict, List, Optional, Tuple

import bullet_safety_gym  # noqa
import dsrl
import gymnasium as gym  # noqa
import gym as gym_org
import numpy as np
import pyrallis
import torch
from dsrl.infos import DENSITY_CFG
from dsrl.offline_env import OfflineEnvWrapper, wrap_env  # noqa
from fsrl.utils import WandbLogger
from fsrl.utils import TensorboardLogger
from torch.utils.data import DataLoader
from tqdm.auto import trange  # noqa

from examples.configs.rtg_configs import RTG_DEFAULT_CONFIG, RTGTrainConfig
from osrl.algorithms import State_AE, Action_AE, inverse_dynamics_model, ActionAETrainer, StateAETrainer
from osrl.algorithms import RTG_model, RTGTrainer, MTRTG, MTRTGTrainer
from osrl.algorithms import SafetyAwareEncoder, MultiHeadDecoder, ContextEncoderTrainer, SimpleMlpEncoder
from osrl.common import SequenceDataset, TransitionDataset
from osrl.common.exp_util import auto_name, seed_all, load_config_and_model


def _prepare_step_inputs(batch, prompt_iter, encoder, device, use_state):
    """Build the inputs of one trainer step from a data batch and a fresh prompt.

    Prompt batches are drawn from ``prompt_iter`` until one is found in which
    not every transition has a positive cost. That prompt batch is then
    passed through ``encoder`` (without gradients), the result is flattened
    to a single vector and broadcast over the batch dimension of ``batch``.

    Args:
        batch: the 8-tuple yielded by the sequence data loader
            ``(states, actions, returns, costs_return, time_steps, mask,
            episode_cost, costs)``.
        prompt_iter: iterator yielding 6-tuples ``(states, next_states,
            actions, rewards, costs, done)`` of prompt transitions.
        encoder: context encoder, called as ``encoder(encoder_input, costs)``.
        device: device every tensor is moved to.
        use_state: when true, ``states`` is included in the returned kwargs.

    Returns:
        ``(returns, costs_return, step_kwargs)`` where ``returns`` has a
        trailing singleton dimension appended and ``step_kwargs`` holds
        ``prompts`` (and ``states`` if ``use_state``) for
        ``train_one_step`` / ``eval_one_step``.
    """
    # Re-draw the prompt while *all* of its transitions have cost > 0.
    while True:
        prompt_batch = next(prompt_iter)
        (prompt_states, prompt_next_states, prompt_actions, prompt_rewards,
         prompt_costs, _prompt_done) = [
            b.to(device).to(torch.float32) for b in prompt_batch
        ]
        if not torch.all(prompt_costs > 0):
            break

    # Only states / returns / costs_return are consumed below; the full
    # unpacking is kept so a batch of unexpected arity still fails loudly.
    (states, _actions, returns, costs_return, _time_steps, _mask,
     _episode_cost, _costs) = [b.to(device) for b in batch]

    with torch.no_grad():
        encoder_input = torch.cat(
            [prompt_states, prompt_actions, prompt_next_states,
             prompt_rewards.reshape(-1, 1)],
            dim=-1,
        )
        prompt_encoding = encoder(encoder_input, prompt_costs)
        # One encoding per prompt batch, shared by every sample of the batch.
        prompt_encoding = prompt_encoding.reshape(1, 1, -1).expand(
            states.shape[0], -1, -1)

    returns = returns.unsqueeze(-1)
    step_kwargs = {"states": states} if use_state else {}
    step_kwargs["prompts"] = prompt_encoding
    return returns, costs_return, step_kwargs


@pyrallis.wrap()
def train(args: RTGTrainConfig):
    """Train one multi-task RTG model (``MTRTG``) on the hard-coded task list.

    The function
      1. merges the user-overridden fields of ``args`` into the default
         config of the first task,
      2. creates every environment, loads and pre-processes its dataset,
      3. builds one trainable ``State_AE`` per environment family and loads
         the frozen per-task state/action auto-encoders and the context
         encoder from disk,
      4. runs ``args.epoch_num`` epochs of ``args.steps_per_epoch`` steps,
         each step training once on every task, followed by one evaluation
         step per task, a checkpoint and a logger write.

    Args:
        args: training configuration; fields that differ from the
            ``RTGTrainConfig`` defaults override ``RTG_DEFAULT_CONFIG`` of
            the first task.

    Raises:
        NotImplementedError: if ``args.density`` is not ``1.0``.
    """
    # ------------------------------------------------------------------ #
    # Per-task tables. All lists below are parallel to `tasks`, except the
    # `env_*` lists which are indexed by the values stored in `task_envs`.
    # ------------------------------------------------------------------ #
    tasks = ["OfflinePointButton1Gymnasium-v0","OfflinePointButton2Gymnasium-v0","OfflinePointCircle1Gymnasium-v0","OfflinePointCircle2Gymnasium-v0",
                  "OfflinePointGoal1Gymnasium-v0","OfflinePointGoal2Gymnasium-v0","OfflinePointPush1Gymnasium-v0","OfflinePointPush2Gymnasium-v0",
                  "OfflineHalfCheetahVelocityGymnasium-v0","OfflineHalfCheetahVelocityGymnasium-v1","OfflineHopperVelocityGymnasium-v0","OfflineHopperVelocityGymnasium-v1",
                  "OfflineCarButton1Gymnasium-v0","OfflineCarButton2Gymnasium-v0","OfflineCarCircle1Gymnasium-v0","OfflineCarCircle2Gymnasium-v0",
                  "OfflineCarGoal1Gymnasium-v0","OfflineCarGoal2Gymnasium-v0","OfflineCarPush1Gymnasium-v0","OfflineCarPush2Gymnasium-v0",
                  "OfflineAntVelocityGymnasium-v0","OfflineAntVelocityGymnasium-v1","OfflineSwimmerVelocityGymnasium-v0","OfflineSwimmerVelocityGymnasium-v1",
                  "OfflineWalker2dVelocityGymnasium-v0","OfflineWalker2dVelocityGymnasium-v1"]
    task_names = ["PointButton1","PointButton2","PointCircle1","PointCircle2","PointGoal1","PointGoal2","PointPush1","PointPush2",
                "HalfCheetahVel-v0","HalfCheetahVel-v1","HopperVel-v0","HopperVel-v1",
                "CarButton1","CarButton2","CarCircle1","CarCircle2","CarGoal1","CarGoal2","CarPush1","CarPush2",
                "AntVel-v0","AntVel-v1","SwimmerVel-v0","SwimmerVel-v1","Walker2dVel-v0","Walker2dVel-v1"]
    # Index of the trainable state encoder used by each task.
    task_envs = [0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7,8,8,9,9,10,10,11,11,12,12]
    state_encoder_paths = [
        "logs/PointButtonGymnasium-cost-10/sa_encoder-718f/sa_encoder-718f_state_AE",
        "logs/PointButtonGymnasium-cost-10/sa_encoder-718f/sa_encoder-718f_state_AE",
        "logs/PointCircleGymnasium-cost-10/sa_encoder-e510/sa_encoder-e510_state_AE",
        "logs/PointCircleGymnasium-cost-10/sa_encoder-e510/sa_encoder-e510_state_AE",
        "logs/PointGoalGymnasium-cost-10/sa_encoder-0739/sa_encoder-0739_state_AE",
        "logs/PointGoalGymnasium-cost-10/sa_encoder-0739/sa_encoder-0739_state_AE",
        "logs/PointPushGymnasium-cost-10/sa_encoder-0710/sa_encoder-0710_state_AE",
        "logs/PointPushGymnasium-cost-10/sa_encoder-0710/sa_encoder-0710_state_AE",
        "logs/HalfCheetahVelocityGymnasium-cost-10/sa_encoder-291b/sa_encoder-291b_state_AE",
        "logs/HalfCheetahVelocityGymnasium-cost-10/sa_encoder-291b/sa_encoder-291b_state_AE",
        "logs/HopperVelocityGymnasium-cost-10/sa_encoder-3a14/sa_encoder-3a14_state_AE",
        "logs/HopperVelocityGymnasium-cost-10/sa_encoder-3a14/sa_encoder-3a14_state_AE",
        "logs/CarButtonGymnasium-cost-10/sa_encoder-0b6c/sa_encoder-0b6c_state_AE",
        "logs/CarButtonGymnasium-cost-10/sa_encoder-0b6c/sa_encoder-0b6c_state_AE",
        "logs/CarCircleGymnasium-cost-10/sa_encoder-8727/sa_encoder-8727_state_AE",
        "logs/CarCircleGymnasium-cost-10/sa_encoder-8727/sa_encoder-8727_state_AE",
        "logs/CarGoalGymnasium-cost-10/sa_encoder-aa9d/sa_encoder-aa9d_state_AE",
        "logs/CarGoalGymnasium-cost-10/sa_encoder-aa9d/sa_encoder-aa9d_state_AE",
        "logs/CarPushGymnasium-cost-10/sa_encoder-cda6/sa_encoder-cda6_state_AE",
        "logs/CarPushGymnasium-cost-10/sa_encoder-cda6/sa_encoder-cda6_state_AE",
        "logs/AntVelocityGymnasium-cost-10/sa_encoder-97c8/sa_encoder-97c8_state_AE",
        "logs/AntVelocityGymnasium-cost-10/sa_encoder-97c8/sa_encoder-97c8_state_AE",
        "logs/SwimmerVelocityGymnasium-cost-10/sa_encoder-8a6f/sa_encoder-8a6f_state_AE",
        "logs/SwimmerVelocityGymnasium-cost-10/sa_encoder-8a6f/sa_encoder-8a6f_state_AE",
        "logs/Walker2dVelocityGymnasium-cost-10/sa_encoder-abf2/sa_encoder-abf2_state_AE",
        "logs/Walker2dVelocityGymnasium-cost-10/sa_encoder-abf2/sa_encoder-abf2_state_AE"
    ]
    # `None` means the task has no pretrained action auto-encoder.
    action_encoder_paths = [
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        "logs/HalfCheetahVelocityGymnasium-cost-10/sa_encoder-291b/sa_encoder-291b_action_AE",
        "logs/HalfCheetahVelocityGymnasium-cost-10/sa_encoder-291b/sa_encoder-291b_action_AE",
        "logs/HopperVelocityGymnasium-cost-10/sa_encoder-3a14/sa_encoder-3a14_action_AE",
        "logs/HopperVelocityGymnasium-cost-10/sa_encoder-3a14/sa_encoder-3a14_action_AE",
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        "logs/AntVelocityGymnasium-cost-10/sa_encoder-97c8/sa_encoder-97c8_action_AE",
        "logs/AntVelocityGymnasium-cost-10/sa_encoder-97c8/sa_encoder-97c8_action_AE",
        None,
        None,
        "logs/Walker2dVelocityGymnasium-cost-10/sa_encoder-abf2/sa_encoder-abf2_action_AE",
        "logs/Walker2dVelocityGymnasium-cost-10/sa_encoder-abf2/sa_encoder-abf2_action_AE"
    ]
    # NOTE: `episode_lens` and `env_action_dims` are not referenced below;
    # they are kept as task metadata next to the tables that are used.
    episode_lens = [1000,1000,500,500,1000,1000,1000,1000,1000,1000,1000,1000,
                    1000,1000,500,500,1000,1000,1000,1000,1000,1000,1000,1000,1000,1000]
    state_dims = [76,76,28,28,60,60,76,76,17,17,11,11,88,88,40,40,72,72,88,88,27,27,8,8,17,17]
    action_dims = [2,2,2,2,2,2,2,2,6,6,3,3,2,2,2,2,2,2,2,2,8,8,2,2,6,6]
    env_state_dims = [76,28,60,76,17,11,88,40,72,88,27,8,17]
    env_action_dims = [2,2,2,2,6,3,2,2,2,2,8,2,6]
    degs=[0,0,1,1,0,1,0,0,1,1,1,1,0,0,1,1,1,1,0,0,1,1,1,1,1,1]
    max_rewards=[45.0,50.0,65.0,55.0,35.0,35.0,20,15,3000,3000,2000,2000,45,50,30,30,50,35,20,15,3000,3000,250,250,3600,3600]
    max_rew_decreases=[5,10,5,5,5,5,5,3,500,500,300,300,10,10,10,10,5,5,5,3,500,500,50,50,800,800]
    min_rewards=[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1]

    # update config: defaults of the first task, overridden by every field the
    # user changed with respect to a pristine RTGTrainConfig
    args.task = tasks[0]
    cfg, old_cfg = asdict(args), asdict(RTGTrainConfig())
    differing_values = {
        key: value for key, value in cfg.items() if value != old_cfg[key]
    }
    cfg = asdict(RTG_DEFAULT_CONFIG[args.task]())
    cfg.update(differing_values)
    args = types.SimpleNamespace(**cfg)
    task_num = len(tasks)

    # setup logger
    default_cfg = asdict(RTG_DEFAULT_CONFIG[args.task]())
    if args.name is None:
        args.name = auto_name(default_cfg, cfg, args.prefix, args.suffix)

    if args.group is None:
        args.group = f"MTRTG-task_num-{task_num}"
    if args.logdir is not None:
        args.logdir = os.path.join(args.logdir, args.group, args.name)
    # logger = WandbLogger(cfg, args.project, args.group, args.name, args.logdir)
    logger = TensorboardLogger(args.logdir, log_txt=True, name=args.name)
    logger.save_config(cfg, verbose=args.verbose)

    # set seed
    seed_all(args.seed)
    if args.device == "cpu":
        torch.set_num_threads(args.threads)

    # environments and raw datasets
    env_ls = []
    data_ls = []
    for task in tasks:
        env = gym.make(task)
        env.set_target_cost(args.cost_limit)
        env_ls.append(env)
        data_ls.append(env.get_dataset())

    # Density filtering is not supported here. The previous guard was
    # `assert False`, which is stripped under `python -O`; the code behind it
    # then looked up DENSITY_CFG for `args.task` (the first task) only and
    # would have applied those bins to every task. Fail explicitly instead.
    if args.density != 1.0:
        raise NotImplementedError(
            f"density={args.density} is not supported for multi-task RTG "
            "training; only density=1.0 is implemented")
    cbins, rbins, max_npb, min_npb = None, None, None, None

    data_ls = [
        env.pre_process_data(data,
                             args.outliers_percent,
                             args.noise_scale,
                             args.inpaint_ranges,
                             args.epsilon,
                             args.density,
                             cbins=cbins,
                             rbins=rbins,
                             max_npb=max_npb,
                             min_npb=min_npb)
        for env, data in zip(env_ls, data_ls)
    ]

    # wrapper
    env_ls = [
        OfflineEnvWrapper(wrap_env(env=env, reward_scale=args.reward_scale))
        for env in env_ls
    ]

    # One trainable state encoder per environment family (see `task_envs`).
    state_encoder_ls = []
    for env_idx in range(task_envs[-1] + 1):
        # linear only is important
        state_encoder = State_AE(
            state_dim=env_state_dims[env_idx],
            encode_dim=args.state_encode_dim,
            hidden_sizes=args.state_encoder_hidden_sizes,
            linear_only=args.linear_only
        )
        state_encoder.to(args.device)
        state_encoder_ls.append(state_encoder)

    # Frozen, pretrained per-task auto-encoders (kept on CPU); they are handed
    # to the prompt TransitionDataset of the corresponding task.
    pretrained_se_ls = []
    pretrained_ae_ls = []
    for se_path, ae_path, s_dim, a_dim in zip(state_encoder_paths,
                                              action_encoder_paths,
                                              state_dims, action_dims):
        senc_cfg, senc_model = load_config_and_model(se_path, True, device=torch.device("cpu"))
        pretrained_se = State_AE(
            state_dim=s_dim,
            encode_dim=senc_cfg["state_encode_dim"],
            hidden_sizes=senc_cfg["state_encoder_hidden_sizes"]
        )
        pretrained_se.load_state_dict(senc_model["model_state"])
        pretrained_se.eval()
        pretrained_se_ls.append(pretrained_se)

        if ae_path is None:
            pretrained_ae_ls.append(None)
            continue
        aenc_cfg, aenc_model = load_config_and_model(ae_path, True, device=torch.device("cpu"))
        pretrained_ae = Action_AE(
            action_dim=a_dim,
            encode_dim=aenc_cfg["action_encode_dim"],
            hidden_sizes=aenc_cfg["action_encoder_hidden_sizes"]
        )
        pretrained_ae.load_state_dict(aenc_model["model_state"])
        pretrained_ae.eval()
        pretrained_ae_ls.append(pretrained_ae)

    # Frozen context encoder producing the prompt encoding.
    enc_cfg, enc_model = load_config_and_model(args.context_encoder_path, False, device=torch.device(args.device))
    enc_cfg = types.SimpleNamespace(**enc_cfg)
    if not enc_cfg.simple_mlp:
        encoder = SafetyAwareEncoder(
            enc_cfg.state_encoding_dim*2+enc_cfg.action_encoding_dim+1,
            enc_cfg.context_encoder_hidden_sizes,
            enc_cfg.context_encoding_dim,
            simple_gate=enc_cfg.simple_gate
        ).to(args.device)
    else:
        encoder = SimpleMlpEncoder(
            enc_cfg.state_encoding_dim*2+enc_cfg.action_encoding_dim+2,
            enc_cfg.context_encoder_hidden_sizes,
            enc_cfg.context_encoding_dim
        ).to(args.device)
    encoder.load_state_dict(enc_model["encoder_state"])
    encoder.eval()

    # model & optimizer & scheduler setup
    state_dim = args.state_encode_dim
    rtg_model = RTG_model(
        state_dim=state_dim,
        prompt_dim=enc_cfg.context_encoding_dim,
        cost_embedding_dim=args.embedding_dim,
        state_embedding_dim=args.embedding_dim,
        prompt_embedding_dim=args.embedding_dim,
        r_hidden_sizes=args.r_hidden_sizes,
        use_state=args.use_state,
        use_prompt=args.use_prompt
    ).to(args.device)

    model = MTRTG(rtg_model, state_encoder_ls)
    model.to(args.device)

    print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")

    def checkpoint_fn():
        return {"model_state": model.state_dict()}

    logger.setup_checkpoint_fn(checkpoint_fn)

    # trainer
    trainer = MTRTGTrainer(model,
                           logger=logger,
                           learning_rate=args.learning_rate,
                           device=args.device,
                           logprob_loss=args.logprob_loss)

    def cost_transform(x):
        return 70 - x if args.linear else 1 / (x + 10)

    # Per-task iterators: training batches, (larger) evaluation batches drawn
    # from the same dataset, and prompt transitions for the context encoder.
    dataloader_iter_ls = []
    prompt_dataloader_iter_ls = []
    test_dataloader_iter_ls = []
    for i, task_data in enumerate(data_ls):
        dataset = SequenceDataset(
            task_data,
            seq_len=1,
            reward_scale=args.reward_scale,
            cost_scale=args.cost_scale,
            deg=degs[i],
            pf_sample=args.pf_sample,
            max_rew_decrease=max_rew_decreases[i],
            beta=args.beta,
            augment_percent=args.augment_percent,
            cost_reverse=args.cost_reverse,
            max_reward=max_rewards[i],
            min_reward=min_rewards[i],
            pf_only=args.pf_only,
            rmin=args.rmin,
            cost_bins=args.cost_bins,
            npb=args.npb,
            cost_sample=args.cost_sample,
            cost_transform=cost_transform,
            start_sampling=args.start_sampling,
            prob=args.prob,
            random_aug=args.random_aug,
            aug_rmin=args.aug_rmin,
            aug_rmax=args.aug_rmax,
            aug_cmin=args.aug_cmin,
            aug_cmax=args.aug_cmax,
            cgap=args.cgap,
            rstd=args.rstd,
            cstd=args.cstd
        )

        trainloader = DataLoader(
            dataset,
            batch_size=args.batch_size,
            pin_memory=True,
            num_workers=0,
        )
        dataloader_iter_ls.append(iter(trainloader))

        testloader = DataLoader(
            dataset,
            batch_size=args.batch_size*100,
            pin_memory=True,
            num_workers=0,
        )
        test_dataloader_iter_ls.append(iter(testloader))

        prompt_dataset = TransitionDataset(task_data,
                                           reward_scale=args.reward_scale,
                                           cost_scale=args.cost_scale,
                                           state_encoder=pretrained_se_ls[i],
                                           action_encoder=pretrained_ae_ls[i]
                                           )
        prompt_loader = DataLoader(
            prompt_dataset,
            batch_size=enc_cfg.context_size,
            pin_memory=True,
            num_workers=0,
        )
        prompt_dataloader_iter_ls.append(iter(prompt_loader))

    # ready for modification
    for epoch in range(args.epoch_num):
        for _step in trange(args.steps_per_epoch, desc="Training"):
            # One optimisation step per task, in task order.
            for i in range(task_num):
                # The data batch is drawn before the prompt batch, as before.
                batch = next(dataloader_iter_ls[i])
                returns, costs_return, step_kwargs = _prepare_step_inputs(
                    batch, prompt_dataloader_iter_ls[i], encoder,
                    args.device, args.use_state)
                trainer.train_one_step(returns, costs_return, task_names[i],
                                       task_envs[i], **step_kwargs)

        # One evaluation step per task at the end of every epoch.
        for i in range(task_num):
            test_batch = next(test_dataloader_iter_ls[i])
            returns, costs_return, step_kwargs = _prepare_step_inputs(
                test_batch, prompt_dataloader_iter_ls[i], encoder,
                args.device, args.use_state)
            trainer.eval_one_step(returns, costs_return, task_names[i],
                                  task_envs[i], **step_kwargs)

        logger.save_checkpoint()
        logger.write(epoch+1, display=False)



# Script entry point. `train` is decorated with `pyrallis.wrap()` at its
# definition, which is why it is invoked here without passing `args`.
if __name__ == "__main__":
    train()
